Explaining Text Classification

from explainer.explainers import feature_attributions_explainer, metrics_explainer
import numpy as np
from sklearn import datasets

# Full list of the 20 Newsgroups category names (kept for reference; only
# the subset below is actually used).
all_categories = ['alt.atheism','comp.graphics','comp.os.ms-windows.misc','comp.sys.ibm.pc.hardware',
                  'comp.sys.mac.hardware','comp.windows.x', 'misc.forsale','rec.autos','rec.motorcycles',
                  'rec.sport.baseball','rec.sport.hockey','sci.crypt','sci.electronics','sci.med',
                  'sci.space','soc.religion.christian','talk.politics.guns','talk.politics.mideast',
                  'talk.politics.misc','talk.religion.misc']

# Five categories selected for this classification demo.
selected_categories = ['alt.atheism','comp.graphics','rec.motorcycles','sci.space','talk.politics.misc']

# Fetch the train/test splits restricted to the selected categories.
# X_*_text are raw document strings, Y_* are integer class labels.
X_train_text, Y_train = datasets.fetch_20newsgroups(subset="train", categories=selected_categories, return_X_y=True)
X_test_text , Y_test  = datasets.fetch_20newsgroups(subset="test", categories=selected_categories, return_X_y=True)

# Convert the document lists to numpy arrays for fancy indexing later on.
X_train_text = np.array(X_train_text)
X_test_text = np.array(X_test_text)

# Unique integer labels and a label -> category-name mapping.
# NOTE(review): this relies on fetch_20newsgroups numbering labels in the
# same (alphabetical) order as selected_categories — verify if the
# category list changes.
classes = np.unique(Y_train)
mapping = dict(zip(classes, selected_categories))

# Notebook display expression: split sizes, labels, and the mapping.
len(X_train_text), len(X_test_text), classes, mapping
(2720,
 1810,
 array([0, 1, 2, 3, 4]),
 {0: 'alt.atheism',
  1: 'comp.graphics',
  2: 'rec.motorcycles',
  3: 'sci.space',
  4: 'talk.politics.misc'})
# Show the integer labels of the test split.
print(Y_test)
[2 3 0 ... 3 2 3]

Vectorize Text Data

import sklearn
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

# TF-IDF bag-of-words vectorizer capped at the 50k most frequent terms.
vectorizer = TfidfVectorizer(max_features=50000)

# NOTE(review): fitting on train+test together leaks test vocabulary/IDF
# statistics into the features — acceptable for an explainability demo,
# but fit on X_train_text only for a rigorous evaluation.
vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)

# Densify the sparse matrices (the Keras model below consumes dense input).
X_train, X_test = X_train.toarray(), X_test.toarray()

# Notebook display expression: (n_documents, 50000) feature shapes.
X_train.shape, X_test.shape
((2720, 50000), (1810, 50000))

Define the Model

from tensorflow.keras.models import Sequential
from tensorflow.keras import layers

def create_model():
    """Build a small feed-forward classifier over the TF-IDF features.

    The network takes the dense TF-IDF vectors as input, passes them
    through two ReLU hidden layers (128 and 64 units), and emits a
    softmax distribution with one unit per class.
    """
    network = Sequential()
    network.add(layers.Input(shape=X_train.shape[1:]))
    network.add(layers.Dense(128, activation="relu"))
    network.add(layers.Dense(64, activation="relu"))
    network.add(layers.Dense(len(classes), activation="softmax"))
    return network

# Instantiate the network and print its layer/parameter summary.
model = create_model()

model.summary()

Compile and Train Model

# Adam optimizer + sparse categorical cross-entropy (integer labels),
# tracking accuracy; 5 epochs, validating on the test split each epoch.
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
history = model.fit(X_train, Y_train, batch_size=256, epochs=5, validation_data=(X_test, Y_test))

Evaluate Model Performance

from sklearn.metrics import accuracy_score, classification_report

# Class-probability predictions for both splits.
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

# argmax over the softmax outputs yields the predicted integer label.
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
Hide code cell output
 1/85 [..............................] - ETA: 17s

 7/85 [=>............................] - ETA: 0s 

13/85 [===>..........................] - ETA: 0s

21/85 [======>.......................] - ETA: 0s

28/85 [========>.....................] - ETA: 0s

36/85 [===========>..................] - ETA: 0s

44/85 [==============>...............] - ETA: 0s

53/85 [=================>............] - ETA: 0s

61/85 [====================>.........] - ETA: 0s

70/85 [=======================>......] - ETA: 0s

80/85 [===========================>..] - ETA: 0s

85/85 [==============================] - 1s 7ms/step
 1/57 [..............................] - ETA: 5s

 8/57 [===>..........................] - ETA: 0s

15/57 [======>.......................] - ETA: 0s

23/57 [===========>..................] - ETA: 0s

32/57 [===============>..............] - ETA: 0s

40/57 [====================>.........] - ETA: 0s

47/57 [=======================>......] - ETA: 0s

54/57 [===========================>..] - ETA: 0s

57/57 [==============================] - 1s 7ms/step
Train Accuracy : 1.000
Test  Accuracy : 0.948

Classification Report : 
                    precision    recall  f1-score   support

       alt.atheism       0.96      0.93      0.95       319
     comp.graphics       0.94      0.96      0.95       389
   rec.motorcycles       0.97      0.99      0.98       398
         sci.space       0.94      0.93      0.93       394
talk.politics.misc       0.93      0.91      0.92       310

          accuracy                           0.95      1810
         macro avg       0.95      0.95      0.95      1810
      weighted avg       0.95      0.95      0.95      1810
# One-hot-encode the integer test labels (identity-matrix row lookup).
oh_Y_test = np.eye(len(classes))[Y_test]

# Build and display the confusion-matrix explainer, then print its
# per-class precision/recall/F1 report.
cm = metrics_explainer['confusionmatrix'](oh_Y_test, test_preds, selected_categories)
cm.visualize()
print(cm.report)
                    precision    recall  f1-score   support

       alt.atheism       0.96      0.93      0.95       319
     comp.graphics       0.94      0.96      0.95       389
   rec.motorcycles       0.97      0.99      0.98       398
         sci.space       0.94      0.93      0.93       394
talk.politics.misc       0.93      0.91      0.92       310

          accuracy                           0.95      1810
         macro avg       0.95      0.95      0.95      1810
      weighted avg       0.95      0.95      0.95      1810
../../_images/e1b2301fc995640fb10605c29bcc70ec6942ee912e57575c857faeebe13694f4.png
# Precision-recall and ROC curves from the metrics explainer.
plotter = metrics_explainer['plot'](oh_Y_test, test_preds, selected_categories)
plotter.pr_curve()
plotter.roc_curve()
import re

# Take two test documents (indices 1 and 2) to explain below.
X_batch_text = X_test_text[1:3]
X_batch = X_test[1:3]

# Show each sample tokenized on runs of non-word characters.
print("Samples : ")
for text in X_batch_text:
    print(re.split(r"\W+", text))
    print()

# Predict the batch: argmax picks the class, max gives its probability.
preds_proba = model.predict(X_batch)
preds = preds_proba.argmax(axis=1)

print("Actual    Target Values : {}".format([selected_categories[target] for target in Y_test[1:3]]))
print("Predicted Target Values : {}".format([selected_categories[target] for target in preds]))
print("Predicted Probabilities : {}".format(preds_proba.max(axis=1)))
Samples : 
['From', 'prb', 'access', 'digex', 'net', 'Pat', 'Subject', 'Re', 'Near', 'Miss', 'Asteroids', 'Q', 'Organization', 'Express', 'Access', 'Online', 'Communications', 'Greenbelt', 'MD', 'USA', 'Lines', '4', 'Distribution', 'sci', 'NNTP', 'Posting', 'Host', 'access', 'digex', 'net', 'TRry', 'the', 'SKywatch', 'project', 'in', 'Arizona', 'pat', '']

['From', 'cobb', 'alexia', 'lis', 'uiuc', 'edu', 'Mike', 'Cobb', 'Subject', 'Science', 'and', 'theories', 'Organization', 'University', 'of', 'Illinois', 'at', 'Urbana', 'Lines', '19', 'As', 'per', 'various', 'threads', 'on', 'science', 'and', 'creationism', 'I', 've', 'started', 'dabbling', 'into', 'a', 'book', 'called', 'Christianity', 'and', 'the', 'Nature', 'of', 'Science', 'by', 'JP', 'Moreland', 'A', 'question', 'that', 'I', 'had', 'come', 'from', 'one', 'of', 'his', 'comments', 'He', 'stated', 'that', 'God', 'is', 'not', 'necessarily', 'a', 'religious', 'term', 'but', 'could', 'be', 'used', 'as', 'other', 'scientific', 'terms', 'that', 'give', 'explanation', 'for', 'events', 'or', 'theories', 'without', 'being', 'a', 'proven', 'scientific', 'fact', 'I', 'think', 'I', 'got', 'his', 'point', 'I', 'can', 'quote', 'the', 'section', 'if', 'I', 'm', 'being', 'vague', 'The', 'examples', 'he', 'gave', 'were', 'quarks', 'and', 'continental', 'plates', 'Are', 'there', 'explanations', 'of', 'science', 'or', 'parts', 'of', 'theories', 'that', 'are', 'not', 'measurable', 'in', 'and', 'of', 'themselves', 'or', 'can', 'everything', 'be', 'quantified', 'measured', 'tested', 'etc', 'MAC', 'Michael', 'A', 'Cobb', 'and', 'I', 'won', 't', 'raise', 'taxes', 'on', 'the', 'middle', 'University', 'of', 'Illinois', 'class', 'to', 'pay', 'for', 'my', 'programs', 'Champaign', 'Urbana', 'Bill', 'Clinton', '3rd', 'Debate', 'cobb', 'alexia', 'lis', 'uiuc', 'edu', 'Nobody', 'can', 'explain', 'everything', 'to', 'anybody', 'G', 'K', 'Chesterton', '']


1/1 [==============================] - ETA: 0s

1/1 [==============================] - 0s 40ms/step
Actual    Target Values : ['sci.space', 'alt.atheism']
Predicted Target Values : ['sci.space', 'alt.atheism']
Predicted Probabilities : [0.92902756 0.78050315]

SHAP Partition Explainer

Visualize SHAP Values Correct Predictions

def make_predictions(X_batch_text):
    """Prediction function handed to the SHAP partition explainer.

    Vectorizes a batch of raw documents with the fitted TF-IDF
    vectorizer, densifies the result, and returns the model's class
    probabilities for the batch.
    """
    dense_features = vectorizer.transform(X_batch_text).toarray()
    return model.predict(dense_features)

# Run the SHAP partition explainer over the two raw-text samples,
# masking tokens split on runs of non-word characters.
partition_explainer = feature_attributions_explainer.partitionexplainer(make_predictions, r"\W+", selected_categories)(X_batch_text)

Text Plot

# Render the token-level SHAP text plot for both samples.
partition_explainer.visualize()


[0]
outputs
alt.atheism
comp.graphics
rec.motorcycles
sci.space
talk.politics.misc


0.50.30.10.70.90.1642390.164239base value0.008652730.00865273falt.atheism(inputs)0.016 Arizona. 0.007 SKywatch 0.006 TRry 0.0 Re: 0.0 Miss 0.0 the -0.017 project -0.015 pat -0.012 digex. -0.011 Pat) -0.011 access. -0.01 sci -0.009 prb@ -0.008 Express -0.008 Access -0.008 net ( -0.008 Communications, -0.007 digex. -0.007 Online -0.007 net -0.007 access. -0.006 Asteroids ( -0.006 USA -0.005 Near -0.005 Distribution: -0.005 MD -0.004 Greenbelt, -0.002 Organization: -0.002 Subject: -0.002 From: -0.002 NNTP- -0.001 Lines: -0.001 in -0.001 Posting- -0.0 Host: -0.0 4 -0.0 Q)
inputs
-0.002
From:
-0.009
prb@
-0.007
access.
-0.007
digex.
-0.008
net (
-0.011
Pat)
-0.002
Subject:
0.0
Re:
-0.005
Near
0.0
Miss
-0.006
Asteroids (
-0.0
Q)
-0.002
Organization:
-0.008
Express
-0.008
Access
-0.007
Online
-0.008
Communications,
-0.004
Greenbelt,
-0.005
MD
-0.006
USA
-0.001
Lines:
-0.0
4
-0.005
Distribution:
-0.01
sci
-0.002
NNTP-
-0.001
Posting-
-0.0
Host:
-0.011
access.
-0.012
digex.
-0.007
net
0.006
TRry
0.0
the
0.007
SKywatch
-0.017
project
-0.001
in
0.016
Arizona.
-0.015
pat


[1]
outputs
alt.atheism
comp.graphics
rec.motorcycles
sci.space
talk.politics.misc


0.50.30.10.70.90.1642390.164239base value0.7805030.780503falt.atheism(inputs)0.069 alexia.lis.uiuc. 0.061 Debate cobb@ 0.052 alexia. 0.051 lis. 0.04 edu Nobody can explain 0.035 necessarily a religious term, 0.031 that God 0.027 his comments. He stated 0.027 cobb@ 0.026 is not 0.026 but 0.025 creationism, I' 0.023 other scientific terms that 0.022 point -- I can quote 0.021 called Christianity and 0.021 give explanation for events 0.02 ve started 0.019 think I got his 0.017 From: 0.016 proven scientific fact. I 0.016 dabbling into a book 0.016 could be used as 0.016 or 0.014 Cobb "...and I won' 0.014 Mike Cobb) 0.011 examples he 0.009 the Nature of Science 0.009 Urbana -Bill 0.008 Clinton 3rd 0.008 t raise taxes on 0.008 question that I had come from one of 0.008 Subject: Science and 0.007 theories, without being a 0.007 the section if I'm being vague. The 0.005 Champaign- 0.005 gave were quarks 0.004 uiuc.edu ( 0.003 programs." 0.002 for 0.002 my -0.018 measured, tested, etc.? -0.015 theories Organization: -0.015 per various -0.014 University of -0.013 19 As -0.012 and continental plates. Are there -0.011 Illinois class to pay -0.01 Illinois at -0.008 and -0.008 Urbana Lines: -0.008 threads on science -0.008 MAC -- **************************************************************** Michael -0.007 A. -0.007 by JP Moreland. A -0.006 K.Chesterton -0.005 theories that -0.004 anybody. G. -0.004 and of themselves, or can everything be quantified, -0.003 parts of -0.003 explanations of science or -0.002 everything to -0.002 the middle University of -0.0 are not measurable in
inputs
0.017
From:
0.027
cobb@
0.052
alexia.
0.051
lis.
0.004 / 2
uiuc.edu (
0.014 / 2
Mike Cobb)
0.008 / 3
Subject: Science and
-0.015 / 2
theories Organization:
-0.014 / 2
University of
-0.01 / 2
Illinois at
-0.008 / 2
Urbana Lines:
-0.013 / 2
19 As
-0.015 / 2
per various
-0.008 / 3
threads on science
-0.008
and
0.025 / 2
creationism, I'
0.02 / 2
ve started
0.016 / 4
dabbling into a book
0.021 / 3
called Christianity and
0.009 / 4
the Nature of Science
-0.007 / 4
by JP Moreland. A
0.008 / 8
question that I had come from one of
0.027 / 4
his comments. He stated
0.031 / 2
that God
0.026 / 2
is not
0.035 / 4
necessarily a religious term,
0.026
but
0.016 / 4
could be used as
0.023 / 4
other scientific terms that
0.021 / 4
give explanation for events
0.016
or
0.007 / 4
theories, without being a
0.016 / 4
proven scientific fact. I
0.019 / 4
think I got his
0.022 / 4
point -- I can quote
0.007 / 8
the section if I'm being vague. The
0.011 / 2
examples he
0.005 / 3
gave were quarks
-0.012 / 5
and continental plates. Are there
-0.003 / 4
explanations of science or
-0.003 / 2
parts of
-0.005 / 2
theories that
-0.0 / 4
are not measurable in
-0.004 / 8
and of themselves, or can everything be quantified,
-0.018 / 3
measured, tested, etc.?
-0.008 / 2
MAC -- **************************************************************** Michael
-0.007
A.
0.014 / 4
Cobb "...and I won'
0.008 / 4
t raise taxes on
-0.002 / 4
the middle University of
-0.011 / 4
Illinois class to pay
0.002
for
0.002
my
0.003
programs."
0.005
Champaign-
0.009 / 2
Urbana -Bill
0.008 / 2
Clinton 3rd
0.061 / 2
Debate cobb@
0.069 / 3
alexia.lis.uiuc.
0.04 / 4
edu Nobody can explain
-0.002 / 2
everything to
-0.004 / 2
anybody. G.
-0.006 / 2
K.Chesterton

Bar Plots

Bar Plot 1

# Grab the shap module handle and the Explanation object off the explainer.
shap = partition_explainer.shap
shap_values = partition_explainer.shap_values

# Token attributions toward sample 0's predicted class, averaged over the
# batch; show the top 15 tokens, largest first.
shap.plots.bar(partition_explainer.shap_values[:,:, selected_categories[preds[0]]].mean(axis=0), max_display=15,
               order=shap.Explanation.argsort.flip)
Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
../../_images/6bb6bb170450fde0ff7575b83be501543ea8ac5908fc45f05f68b8d38b4fe0f8.png

Bar Plot 2

# Token attributions of sample 0 toward its predicted class (top 15).
shap.plots.bar(shap_values[0,:, selected_categories[preds[0]]], max_display=15,
               order=shap.Explanation.argsort.flip)
../../_images/d5a0610a952a009dbaf275586b38ed3f4a5e3d62fab0dd81f6a960d717efecb1.png

Bar Plot 3

# Batch-averaged token attributions toward sample 1's predicted class.
shap.plots.bar(shap_values[:,:, selected_categories[preds[1]]].mean(axis=0), max_display=15,
               order=shap.Explanation.argsort.flip)
../../_images/0f30253eeaefe4856d7924e37733a6d1f0fde1c7845608273ad9189c384a86ff.png

Bar Plot 4

# Token attributions of sample 1 toward its predicted class (top 15).
shap.plots.bar(shap_values[1,:, selected_categories[preds[1]]], max_display=15,
               order=shap.Explanation.argsort.flip)
../../_images/9b4f5d9c3231b10042a9e03673d3bf0d2b0badfeab86a9ddf2569e573bd825f4.png

Waterfall Plots

Waterfall Plot 1

# Waterfall of sample 0's strongest token attributions toward its class.
shap.waterfall_plot(shap_values[0][:, selected_categories[preds[0]]], max_display=15)
../../_images/327060f328d90fdc5c7f7b3e2761e866b23e06817027ef8573d665b5485681db.png

Waterfall Plot 2

# Waterfall of sample 1's strongest token attributions toward its class.
shap.waterfall_plot(shap_values[1][:, selected_categories[preds[1]]], max_display=15)
../../_images/2222178232bf6d31e7ef13039d8a0e6b65f00321408dc98cc52802c21fbc2d79.png

Force Plot

import re

# Tokenize the first sample the same way the explainer did.
# Raw string fixes the invalid "\W" escape sequence (DeprecationWarning,
# SyntaxWarning on newer CPython) and matches the r"\W+" pattern used
# everywhere else in this notebook.
tokens = re.split(r"\W+", X_batch_text[0].lower())
shap.initjs()
# Force plot of sample 0's attributions toward its predicted class;
# tokens[:-1] drops the trailing empty string re.split leaves when the
# document ends with a delimiter.
shap.force_plot(shap_values.base_values[0][preds[0]], shap_values[0][:, preds[0]].values,
                feature_names = tokens[:-1], out_names=selected_categories[preds[0]])
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.